package com.yinsin.jpabatis.config;

import java.lang.reflect.Field;
import java.lang.reflect.Method;
import java.util.HashMap;
import java.util.Map;

import javax.annotation.PostConstruct;
import javax.persistence.EntityManager;
import javax.persistence.PersistenceContext;

import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.autoconfigure.condition.ConditionalOnClass;
import org.springframework.context.ApplicationListener;
import org.springframework.context.event.ContextRefreshedEvent;
import org.springframework.data.jpa.repository.JpaRepository;
import org.springframework.stereotype.Component;

import com.yinsin.jpabatis.annotations.JpaMapper;
import com.yinsin.jpabatis.exceptions.JpaBatisException;
import com.yinsin.jpabatis.mapper.ClassPathMapperScanner;
import com.yinsin.jpabatis.session.JpaSession;
import com.yinsin.jpabatis.util.AopTargetUtils;

/**
 * Injects {@link JpaSession} instances into Spring beans after the application context has been
 * refreshed.
 * <p>
 * After construction, a shared {@code com.yinsin.jpabatis.config.Configuration} is created, given
 * the container-managed {@link EntityManager}, and passed to {@link ClassPathMapperScanner} so the
 * mapper definitions are parsed once. On every {@link ContextRefreshedEvent}, each bean defined in
 * the refreshed context is unwrapped via {@link AopTargetUtils#getTarget(Object)} and scanned for:
 * <ul>
 * <li>declared fields annotated with {@link JpaMapper} whose type is {@link JpaSession}, and</li>
 * <li>public methods annotated with {@link JpaMapper} that take exactly one {@link JpaSession}
 * argument.</li>
 * </ul>
 * Within a single bean, members whose {@link JpaMapper#value()} is equal (an empty value maps to a
 * common default key) receive the same {@link JpaSession} instance.
 */
@Component
@ConditionalOnClass(JpaRepository.class)
public class ContextRefreshedListener implements ApplicationListener<ContextRefreshedEvent> {

	/** Session key used when a {@link JpaMapper} annotation has an empty value. */
	private static final String DEFAULT_SESSION_KEY = "_JpaMapper";

	@PersistenceContext
	private EntityManager em;

	@Autowired
	private JpaBatisProperties jpaBatisProperties;

	/** Shared configuration passed to every {@link JpaSession} this listener creates. */
	private com.yinsin.jpabatis.config.Configuration configuration;

	/**
	 * Builds the shared configuration and parses the mapper definitions once this bean's
	 * dependencies have been injected.
	 *
	 * @throws JpaBatisException wrapping any failure while building the configuration or parsing
	 *             the mappers
	 */
	@PostConstruct
	private void xmlMapperScan() {
		try {
			configuration = new com.yinsin.jpabatis.config.Configuration();
			configuration.setTransaction(em);
			new ClassPathMapperScanner(jpaBatisProperties, configuration).parseMapper();
		} catch (Exception e) {
			throw new JpaBatisException(e);
		}
	}

	/**
	 * Performs {@link JpaSession} injection on every bean defined in the refreshed context.
	 *
	 * @param paramE the refresh event carrying the application context
	 * @throws JpaBatisException wrapping the first failure encountered; remaining beans are not
	 *             processed
	 */
	@Override
	public void onApplicationEvent(ContextRefreshedEvent paramE) {
		// NOTE(review): getBean is called for every bean definition name, so lazy-init beans are
		// initialized here and prototype beans get a fresh instance that is injected and then
		// discarded; abstract bean definitions would presumably make getBean throw. Consider a
		// BeanPostProcessor if that matters for callers.
		String[] names = paramE.getApplicationContext().getBeanDefinitionNames();
		for (String beanName : names) {
			try {
				injection(paramE.getApplicationContext().getBean(beanName));
			} catch (Exception e) {
				throw new JpaBatisException(e);
			}
		}
	}

	/**
	 * Unwraps the given bean and injects sessions into its annotated fields, then its annotated
	 * methods. A {@code null} bean is ignored.
	 *
	 * @param proxy the bean as returned by the context (possibly an AOP proxy)
	 * @throws Exception if unwrapping or method injection fails
	 */
	private void injection(Object proxy) throws Exception {
		if (null == proxy) {
			return;
		}
		Object service = AopTargetUtils.getTarget(proxy);
		// Per-bean cache so that members sharing an annotation value share one session.
		Map<String, JpaSession> sessionMap = new HashMap<String, JpaSession>();
		injectFields(service, sessionMap);
		injectMethods(service, sessionMap);
	}

	/** Sets every {@link JpaMapper}-annotated {@link JpaSession} field declared by the target class. */
	private void injectFields(Object service, Map<String, JpaSession> sessionMap) throws Exception {
		for (Field field : service.getClass().getDeclaredFields()) {
			JpaMapper mapper = field.getAnnotation(JpaMapper.class);
			if (null == mapper || !isJpaSessionType(field.getType())) {
				continue;
			}
			if (!field.isAccessible()) {
				field.setAccessible(true);
			}
			try {
				field.set(service, resolveSession(mapper, sessionMap));
			} catch (IllegalAccessException e) {
				throw new RuntimeException("injection " + field.getName() + " error by " + service.getClass().getName(), e);
			}
		}
	}

	/** Invokes every public {@link JpaMapper}-annotated method taking exactly one {@link JpaSession}. */
	private void injectMethods(Object service, Map<String, JpaSession> sessionMap) throws Exception {
		for (Method method : service.getClass().getMethods()) {
			JpaMapper mapper = method.getAnnotation(JpaMapper.class);
			if (null == mapper) {
				continue;
			}
			Class<?>[] parameterTypes = method.getParameterTypes();
			// Require exactly one parameter: an annotated no-arg method must be skipped rather than
			// indexed at [0], which would throw ArrayIndexOutOfBoundsException.
			if (parameterTypes.length != 1 || !isJpaSessionType(parameterTypes[0])) {
				continue;
			}
			if (!method.isAccessible()) {
				method.setAccessible(true);
			}
			try {
				method.invoke(service, resolveSession(mapper, sessionMap));
			} catch (IllegalAccessException e) {
				throw new RuntimeException("injection " + method.getName() + " error by " + service.getClass().getName(), e);
			}
		}
	}

	/** Returns whether {@code type} is {@link JpaSession}, compared by fully-qualified class name. */
	private static boolean isJpaSessionType(Class<?> type) {
		return type.getName().equals(JpaSession.class.getName());
	}

	/**
	 * Returns the session cached under the annotation's key, creating and caching a new one bound to
	 * the shared configuration if none exists yet.
	 */
	private JpaSession resolveSession(JpaMapper mapper, Map<String, JpaSession> sessionMap) throws Exception {
		String nameKey = mapper.value();
		if (null == nameKey || nameKey.isEmpty()) {
			nameKey = DEFAULT_SESSION_KEY;
		}
		JpaSession session = sessionMap.get(nameKey);
		if (null == session) {
			session = new JpaSession(configuration);
			sessionMap.put(nameKey, session);
		}
		return session;
	}

}
